package kafka_test

import (
	"fmt"
	"testing"
	"time"

	"kyanos/agent/buffer"
	"kyanos/agent/protocol"
	"kyanos/agent/protocol/kafka"
	. "kyanos/agent/protocol/kafka/common"

	// . "kyanos/agent/protocol/kafka/decoder"

	"github.com/stretchr/testify/assert"
)

// PacketsEqual reports whether lhs and rhs carry the same Msg and the same
// CorrelationID. No other fields are compared.
func PacketsEqual(lhs, rhs Packet) bool {
	return lhs.Msg == rhs.Msg && lhs.CorrelationID == rhs.CorrelationID
}

// kProduceRequest is a captured Produce request frame (API key 0, API
// version 9, correlation ID 4, client ID "console-producer") that writes the
// record "This is my first event" to topic "quickstart-events".
// The leading big-endian length 0x98 (152) counts every byte after itself.
var kProduceRequest = []byte{
	0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x04, 0x00, 0x10, 0x63, 0x6f,
	0x6e, 0x73, 0x6f, 0x6c, 0x65, 0x2d, 0x70, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x65, 0x72, 0x00, 0x00,
	0x00, 0x01, 0x00, 0x00, 0x05, 0xdc, 0x02, 0x12, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61,
	0x72, 0x74, 0x2d, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x02, 0x00, 0x00, 0x00, 0x00, 0x5b, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4e, 0xff, 0xff, 0xff, 0xff, 0x02,
	0xc0, 0xde, 0x91, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x7a, 0x1b, 0xc8,
	0x2d, 0xaa, 0x00, 0x00, 0x01, 0x7a, 0x1b, 0xc8, 0x2d, 0xaa, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0x00, 0x00, 0x00,
	0x01, 0x2c, 0x54, 0x68, 0x69, 0x73, 0x20, 0x69, 0x73, 0x20, 0x6d, 0x79, 0x20, 0x66, 0x69, 0x72,
	0x73, 0x74, 0x20, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x00,
}

// kProduceResponse is a captured Produce response frame with correlation ID 4
// (matching kProduceRequest) for topic "quickstart-events".
// The leading big-endian length 0x40 (64) counts every byte after itself.
var kProduceResponse = []byte{
	0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x04, 0x00, 0x02, 0x12, 0x71, 0x75, 0x69,
	0x63, 0x6b, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2d, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
	0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}

// kMetaDataRequest is a captured Metadata request frame.
// APIKey: 3, APIVersion: 11, correlation ID 1, client ID "adminclient-1".
// The leading big-endian length 0x1c (28) counts every byte after itself.
var kMetaDataRequest = []byte{
	0x00, 0x00, 0x00, 0x1c, 0x00, 0x03, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0d, 0x61, 0x64,
	0x6d, 0x69, 0x6e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, 0x31, 0x00, 0x01, 0x01, 0x00, 0x00,
}
// kMetaDataResponse is a captured Metadata response frame with correlation
// ID 1 (matching kMetaDataRequest), advertising a broker at "localhost" on
// port 0x2384 (9092). The leading big-endian length 0x3b (59) counts every
// byte after itself.
var kMetaDataResponse = []byte{
	0x00, 0x00, 0x00, 0x3b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00,
	0x00, 0x00, 0x0a, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, 0x00, 0x00, 0x23, 0x84,
	0x00, 0x00, 0x17, 0x5a, 0x65, 0x76, 0x76, 0x4e, 0x66, 0x47, 0x45, 0x52, 0x30, 0x4f, 0x73, 0x51,
	0x4d, 0x34, 0x77, 0x71, 0x48, 0x5f, 0x6f, 0x75, 0x77, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00,
}

// kAPIVersionRequest is a captured ApiVersions request frame.
// APIKey: 18, APIVersion: 3, correlation ID 2, client ID "adminclient-1",
// client software "apache-kafka-java" version "2.8.0".
// The leading big-endian length 0x31 (49) counts every byte after itself.
var kAPIVersionRequest = []byte{
	0x00, 0x00, 0x00, 0x31, 0x00, 0x12, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x0d,
	0x61, 0x64, 0x6d, 0x69, 0x6e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, 0x31, 0x00,
	0x12, 0x61, 0x70, 0x61, 0x63, 0x68, 0x65, 0x2d, 0x6b, 0x61, 0x66, 0x6b, 0x61, 0x2d,
	0x6a, 0x61, 0x76, 0x61, 0x06, 0x32, 0x2e, 0x38, 0x2e, 0x30, 0x00,
}
// kAPIVersionResponse is a captured ApiVersions response frame with
// correlation ID 2 (matching kAPIVersionRequest) and error code 0. It is
// followed by the broker's list of supported API version ranges.
// The leading big-endian length 0x019e (414) counts every byte after itself.
var kAPIVersionResponse = []byte{
	0x00, 0x00, 0x01, 0x9e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x39, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x09, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x00,
	0x00, 0x03, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05,
	0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x07, 0x00, 0x00,
	0x00, 0x03, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x07,
	0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00,
	0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x0e, 0x00,
	0x00, 0x00, 0x05, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
	0x04, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x03, 0x00,
	0x00, 0x13, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x15,
	0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x17, 0x00, 0x00,
	0x00, 0x04, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x03,
	0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
	0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x1e, 0x00,
	0x00, 0x00, 0x02, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
	0x04, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x02, 0x00,
	0x00, 0x23, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x25,
	0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x27, 0x00, 0x00,
	0x00, 0x02, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x02,
	0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00,
	0x2c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2e, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x2f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
	0x01, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x39,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00,
}

// GenPacket parses rawPacket as a single Kafka message of type msgType using a
// fresh parser and stream buffer. It returns the first parsed *Packet with its
// timestamp set to timestamp.
//
// correlationID is inserted into the parser's correlation-ID map before
// parsing. Presumably this lets a response frame be treated as the answer to
// an already-seen request; confirm against the parser implementation.
//
// t may be nil, because the package-level fixtures call GenPacket during
// package initialization. When parsing yields no message, or the first
// message is not a *Packet, GenPacket calls t.Fatal if t is non-nil and
// otherwise panics with a descriptive message. Without this check the failure
// would be an opaque index-out-of-range or type-assertion panic.
func GenPacket(t *testing.T, msgType protocol.MessageType, rawPacket []byte, timestamp int, correlationID int) *Packet {
	if t != nil {
		t.Helper()
	}
	fail := func(msg string) {
		if t != nil {
			t.Helper()
			t.Fatal(msg)
		}
		panic(msg)
	}

	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)
	parser.GetCorrelationIdMap()[int32(correlationID)] = struct{}{}
	streamBuffer.Add(1, rawPacket, uint64(timestamp))

	parseState := parser.ParseStream(streamBuffer, msgType)
	if len(parseState.ParsedMessages) == 0 {
		fail(fmt.Sprintf("GenPacket: no message parsed from %d-byte frame (correlation ID %d, parse state %v)",
			len(rawPacket), correlationID, parseState.ParseState))
	}
	packet, ok := parseState.ParsedMessages[0].(*Packet)
	if !ok {
		fail(fmt.Sprintf("GenPacket: first parsed message is %T, want *Packet", parseState.ParsedMessages[0]))
	}
	packet.SetTimeStamp(uint64(timestamp))
	return packet
}

// Parsed fixtures, built once at package initialization from the raw frames.
// Each request/response pair uses the correlation ID encoded in its bytes
// (Produce: 4, Metadata: 1, ApiVersions: 2). Timestamps 0 through 5 place
// every request before its response. No *testing.T exists at this point, so
// GenPacket is given nil and a parse failure surfaces as a panic.
var (
	// Produce exchange, correlation ID 4.
	KProduceReqPacket  = GenPacket(nil, protocol.Request, kProduceRequest, 0, 4)
	KProduceRespPacket = GenPacket(nil, protocol.Response, kProduceResponse, 1, 4)

	// Metadata exchange, correlation ID 1.
	KMetaDataReqPacket  = GenPacket(nil, protocol.Request, kMetaDataRequest, 2, 1)
	KMetaDataRespPacket = GenPacket(nil, protocol.Response, kMetaDataResponse, 3, 1)

	// ApiVersions exchange, correlation ID 2.
	KAPIVersionReqPacket  = GenPacket(nil, protocol.Request, kAPIVersionRequest, 4, 2)
	KAPIVersionRespPacket = GenPacket(nil, protocol.Response, kAPIVersionResponse, 5, 2)
)

// TestKafkaParserBasics covers two cases:
//   - A complete Produce request parses successfully and records correlation ID 4.
//   - A frame one byte shorter than KMinReqPacketLength yields NeedsMoreData
//     and records nothing.
func TestKafkaParserBasics(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// kProduceRequest is already a []byte, so no conversion is needed.
	produceFrame := kProduceRequest
	streamBuffer.Add(1, produceFrame, uint64(time.Now().Nanosecond()))
	parseState := parser.ParseStream(streamBuffer, protocol.Request)
	assert.Equal(t, protocol.Success, parseState.ParseState)
	_, containsCorrelationID := parser.GetCorrelationIdMap()[4]
	assert.True(t, containsCorrelationID, "correlation ID 4 should be recorded after parsing the Produce request")

	// A fresh parser fed a truncated frame must wait for more data.
	parser = kafka.NewKafkaStreamParser()
	streamBuffer.Clear()
	shortProduceFrame := produceFrame[:KMinReqPacketLength-1]
	streamBuffer.Add(1, shortProduceFrame, uint64(time.Now().Nanosecond()))
	parseState = parser.ParseStream(streamBuffer, protocol.Request)
	assert.Equal(t, protocol.NeedsMoreData, parseState.ParseState)
	assert.Empty(t, parser.GetCorrelationIdMap())
}

// TestKafkaParserParseMultipleRequests queues a Produce request (correlation
// ID 4) followed by a Metadata request (correlation ID 1). It checks that two
// successive ParseStream calls, with the consumed prefix removed in between,
// record both correlation IDs.
func TestKafkaParserParseMultipleRequests(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Add two back-to-back requests; the second starts right after the first.
	produceReq := kProduceRequest
	metadataReq := kMetaDataRequest
	streamBuffer.Add(1, produceReq, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(produceReq))+1, metadataReq, uint64(time.Now().Nanosecond()+1))

	parseState := parser.ParseStream(streamBuffer, protocol.Request)
	assert.Equal(t, protocol.Success, parseState.ParseState)
	_, containsCorrelationID1 := parser.GetCorrelationIdMap()[4]
	assert.True(t, containsCorrelationID1, "correlation ID 4 should be recorded after the first parse")

	// Drop the bytes consumed by the first parse and continue with the rest.
	streamBuffer.RemovePrefix(parseState.ReadBytes)
	parseState = parser.ParseStream(streamBuffer, protocol.Request)
	assert.Equal(t, protocol.Success, parseState.ParseState)
	_, containsCorrelationID2 := parser.GetCorrelationIdMap()[1]
	assert.True(t, containsCorrelationID2, "correlation ID 1 should be recorded after the second parse")
}

// TestKafkaParserParseMultipleResponses queues a Produce response (correlation
// ID 4) followed by a Metadata response (correlation ID 1). It checks that two
// successive ParseStream calls return packets with those IDs in order.
func TestKafkaParserParseMultipleResponses(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Add two back-to-back responses to the buffer.
	produceResp := kProduceResponse
	metadataResp := kMetaDataResponse
	streamBuffer.Add(1, produceResp, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(produceResp))+1, metadataResp, uint64(time.Now().Nanosecond()+1))

	parseState := parser.ParseStream(streamBuffer, protocol.Response)
	assert.Equal(t, protocol.Success, parseState.ParseState)
	// Fail the test, rather than panic, if nothing (or a non-Packet) was parsed.
	if !assert.NotEmpty(t, parseState.ParsedMessages) {
		return
	}
	packet, ok := parseState.ParsedMessages[0].(*Packet)
	if !assert.True(t, ok, "first parsed message should be a *Packet") {
		return
	}
	assert.EqualValues(t, 4, packet.CorrelationID)

	// Drop the bytes consumed by the first parse and continue with the rest.
	streamBuffer.RemovePrefix(parseState.ReadBytes)
	parseState = parser.ParseStream(streamBuffer, protocol.Response)
	assert.Equal(t, protocol.Success, parseState.ParseState)
	if !assert.NotEmpty(t, parseState.ParsedMessages) {
		return
	}
	packet, ok = parseState.ParsedMessages[0].(*Packet)
	if !assert.True(t, ok, "second parsed message should be a *Packet") {
		return
	}
	assert.EqualValues(t, 1, packet.CorrelationID)
}

// TestKafkaParserParseIncompleteRequest checks that a Produce request missing
// its final byte yields NeedsMoreData, with no parsed messages and no recorded
// correlation IDs.
func TestKafkaParserParseIncompleteRequest(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Add a single request truncated by one byte.
	truncatedReq := kProduceRequest[:len(kProduceRequest)-1]
	streamBuffer.Add(1, truncatedReq, uint64(time.Now().Nanosecond()))

	parseState := parser.ParseStream(streamBuffer, protocol.Request)

	assert.Equal(t, protocol.NeedsMoreData, parseState.ParseState)
	assert.Empty(t, parser.GetCorrelationIdMap())
	assert.Empty(t, parseState.ParsedMessages)
}

// TestKafkaParserParseInvalidInput checks that non-Kafka bytes yield Invalid,
// with no parsed messages and no recorded correlation IDs.
func TestKafkaParserParseInvalidInput(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Bytes shaped like a MySQL COM_QUERY packet: 3-byte little-endian length
	// 0x18, sequence 0, command 0x03, then the query text.
	foreignReq := []byte("\x00\x00\x18\x00\x03SELECT name FROM users;")
	streamBuffer.Add(1, foreignReq, uint64(time.Now().Nanosecond()))

	parseState := parser.ParseStream(streamBuffer, protocol.Request)

	assert.Equal(t, protocol.Invalid, parseState.ParseState)
	assert.Empty(t, parser.GetCorrelationIdMap())
	assert.Empty(t, parseState.ParsedMessages)
}

// TestKafkaParserFindReqBoundaryAligned checks that FindBoundary returns
// offset 0 when the buffer starts exactly on a request frame.
func TestKafkaParserFindReqBoundaryAligned(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	produceReq := kProduceRequest
	metadataReq := kMetaDataRequest
	streamBuffer.Add(1, produceReq, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(produceReq)+1), metadataReq, uint64(time.Now().Nanosecond()+1))

	boundary := parser.FindBoundary(streamBuffer, protocol.Request, 0)
	// testify convention: expected value first, actual second.
	assert.Equal(t, 0, boundary)
}

// TestKafkaParserFindReqBoundaryUnAligned checks that FindBoundary skips
// leading garbage and returns the offset of the first request frame.
func TestKafkaParserFindReqBoundaryUnAligned(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Garbage, then two back-to-back requests.
	garbage := []byte("some garbage")
	produceReq := kProduceRequest
	metadataReq := kMetaDataRequest
	streamBuffer.Add(1, garbage, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(garbage)+1), produceReq, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(garbage)+len(produceReq))+1, metadataReq, uint64(time.Now().Nanosecond()+1))

	boundary := parser.FindBoundary(streamBuffer, protocol.Request, 0)
	assert.Equal(t, len(garbage), boundary)
}

// TestKafkaParserFindRespBoundaryAligned checks that FindBoundary returns
// offset 0 when the buffer starts on a response whose correlation ID has
// already been seen.
func TestKafkaParserFindRespBoundaryAligned(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	produceResp := kProduceResponse
	metadataResp := kMetaDataResponse
	streamBuffer.Add(1, produceResp, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(produceResp)+1), metadataResp, uint64(time.Now().Nanosecond()+1))

	// Mark both responses' correlation IDs as seen.
	parser.GetCorrelationIdMap()[4] = struct{}{}
	parser.GetCorrelationIdMap()[1] = struct{}{}
	boundary := parser.FindBoundary(streamBuffer, protocol.Response, 0)
	assert.Equal(t, 0, boundary)
}

// TestKafkaParserFindRespBoundaryUnAligned covers two cases:
//   - FindBoundary skips leading garbage to the first response whose
//     correlation ID has been seen.
//   - A response whose correlation ID is unknown is skipped as well.
func TestKafkaParserFindRespBoundaryUnAligned(t *testing.T) {
	parser := kafka.NewKafkaStreamParser()
	streamBuffer := buffer.New(10000)

	// Garbage, then a Produce response (ID 4), then a Metadata response (ID 1).
	garbage := []byte("some garbage")
	produceResp := kProduceResponse
	metadataResp := kMetaDataResponse
	streamBuffer.Add(1, garbage, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(garbage)+1), produceResp, uint64(time.Now().Nanosecond()))
	streamBuffer.Add(uint64(len(garbage)+len(produceResp))+1, metadataResp, uint64(time.Now().Nanosecond()+1))

	parser.GetCorrelationIdMap()[4] = struct{}{}
	parser.GetCorrelationIdMap()[1] = struct{}{}
	boundary := parser.FindBoundary(streamBuffer, protocol.Response, 0)
	assert.Equal(t, len(garbage), boundary)

	// With correlation ID 4 (the Produce response) unseen, FindBoundary should
	// skip past it and land on the Metadata response instead.
	delete(parser.GetCorrelationIdMap(), 4)
	boundary = parser.FindBoundary(streamBuffer, protocol.Response, 0)
	assert.Equal(t, len(garbage)+len(produceResp), boundary)
}

// TestKafkaParserMatch covers three cases:
//   - Match produces no records for empty streams.
//   - An unanswered request stays queued.
//   - Once the corresponding response arrives, Match emits one record and
//     drains both queues.
func TestKafkaParserMatch(t *testing.T) {
	reqStreams := make(map[protocol.StreamId]*protocol.ParsedMessageQueue)
	respStreams := make(map[protocol.StreamId]*protocol.ParsedMessageQueue)
	parser := kafka.NewKafkaStreamParser()

	// Nothing to match yet.
	records := parser.Match(reqStreams, respStreams)
	assert.Empty(t, records)

	// A lone request is left in place.
	reqStream := &protocol.ParsedMessageQueue{}
	reqStreams[0] = reqStream
	*reqStream = append(*reqStream, KProduceReqPacket)
	records = parser.Match(reqStreams, respStreams)
	assert.Empty(t, records)
	assert.Len(t, *reqStream, 1)

	// Adding the response with the same correlation ID yields one record.
	respStream := &protocol.ParsedMessageQueue{}
	respStreams[0] = respStream
	*respStream = append(*respStream, KProduceRespPacket)
	records = parser.Match(reqStreams, respStreams)
	assert.Len(t, records, 1)
	assert.Empty(t, *reqStream)
	assert.Empty(t, *respStream)
}
